import os

import numpy as np

from optimizer import inexact_newton, localSGD, fedAC
from function import logistic
from data import loada9a
from run_help import tune_lr


# Store keys of every algorithm `tune` knows how to tune, in the order they
# appear in the returned store.
_ALGORITHM_KEYS = ("Newton_w_M", "LSGD_w_M", "MBSGD_w_M", "FEDAC1", "FEDAC2",
                   "Newton", "LSGD", "MBSGD")


def _tuning_specs(F, mu, M, K, R):
    """Return ``{store_key: (optimizer, label, kwargs)}`` for one value of R.

    Fresh kwargs dicts are built on every call so nothing is shared between
    tuning runs. ``lr`` is 0 everywhere; presumably ``tune_lr`` overwrites it
    with each candidate learning rate -- confirm in run_help.
    """
    def newton(momentum):
        # Inexact Newton is run with T=R and a single round (R=1).
        return {"F": F, "T": R, "alpha": 1.25, "M": M, "K": K, "R": 1,
                "lr": 0, "momentum": momentum, "mu": 0, "damp": True}

    def sgd(m, k, momentum):
        return {"F": F, "M": m, "K": k, "R": R,
                "mu": 0, "u": None, "lr": 0, "momentum": momentum,
                "forHVP": False}

    def fedac(ver):
        return {"F": F, "M": M, "K": K, "R": R,
                "mu": mu, "lr": 0, "ver": ver, "u": None}

    return {
        "Newton_w_M": (inexact_newton, "Inexact Newton w/ Momentum", newton(0.9)),
        "LSGD_w_M": (localSGD, "Local SGD w/ Momentum", sgd(M, K, 0.9)),
        # Mini-batch SGD: M*K samples per step, one local step per round.
        "MBSGD_w_M": (localSGD, "Mini-batch SGD w/ Momentum", sgd(M * K, 1, 0.9)),
        "FEDAC1": (fedAC, "FedAC-1", fedac(1)),
        "FEDAC2": (fedAC, "FedAC-2", fedac(2)),
        "Newton": (inexact_newton, "Inexact Newton", newton(0)),
        "LSGD": (localSGD, "Local SGD", sgd(M, K, 0)),
        "MBSGD": (localSGD, "Mini-batch SGD", sgd(M * K, 1, 0)),
    }


def tune(F, mu, M, lrs, T, R_values, algorithms=("FEDAC1", "FEDAC2")):
    """Tune the learning rate of each selected algorithm for every R in R_values.

    For each R the per-round step count is K = T // R, so every run uses the
    same total budget K * R == T. For each selected algorithm, ``tune_lr`` is
    called with the candidate ``lrs`` and the ``(min_lr, min_loss)`` it
    reports is recorded.

    Args:
        F: Objective forwarded unchanged to every optimizer as ``F``.
        mu: Forwarded as ``mu`` to the FedAC variants only; the Newton and
            SGD baselines are always configured with ``mu=0``.
        M: Forwarded as ``M`` (``M * K`` for the mini-batch SGD baselines);
            presumably the number of clients -- confirm against optimizer.
        lrs: Candidate learning rates passed to ``tune_lr``.
        T: Total budget; must be divisible by every R.
        R_values: Iterable of R values to sweep.
        algorithms: Store keys to tune, in execution order. The default tunes
            only FedAC-1 then FedAC-2; any key of ``_ALGORITHM_KEYS`` may be
            selected.

    Returns:
        dict mapping every key of ``_ALGORITHM_KEYS`` to a list holding one
        ``(min_lr, min_loss)`` tuple per R for tuned algorithms, and an empty
        list for algorithms that were not selected.

    Raises:
        ValueError: if ``algorithms`` contains an unknown key, or T is not
            divisible by some R. Both are checked before any tuning starts so
            a typo does not surface only after hours of computation.
    """
    unknown = [key for key in algorithms if key not in _ALGORITHM_KEYS]
    if unknown:
        raise ValueError(
            f"unknown algorithm(s) {unknown}; expected keys from {_ALGORITHM_KEYS}")

    # Materialize so a generator can be validated and then iterated.
    R_values = list(R_values)
    indivisible = [R for R in R_values if T % R]
    if indivisible:
        raise ValueError(
            f"T={T} is not divisible by R in {indivisible}; K*R would not equal T")

    # Every key is always present so the saved store has a fixed layout.
    store = {key: [] for key in _ALGORITHM_KEYS}

    for R in R_values:
        K = int(T // R)
        specs = _tuning_specs(F, mu, M, K, R)
        for key in algorithms:
            optimizer, label, args = specs[key]
            min_lr, min_loss, _, _ = tune_lr(optimizer, label, lrs, **args)
            store[key].append((min_lr, min_loss))

    return store


F = logistic
mu_values = [1e-2, 1e-4, 1e-5]
M_values = [20, 50, 100, 200]

# Candidate learning rates handed to tune_lr for every (algorithm, R) pair.
lrs = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10]

T = 100  # total budget: K*R = 100 for every R below
R_values = [1, 5, 10, 25, 50, 100]

RESULTS_DIR = "results"


if __name__ == "__main__":
    # np.save does not create parent directories; create it before the sweep
    # so a missing folder does not discard the first (expensive) tuning run.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    for mu in mu_values:
        for M in M_values:
            print(f"[*] Running for mu = {mu} and M = {M}")
            store = tune(F, mu, M, lrs, T, R_values)
            # store is a dict, which np.save stores as a pickled 0-d object
            # array; read it back with np.load(path, allow_pickle=True).item().
            np.save(f"{RESULTS_DIR}/store_{mu}_{M}.npy", store)
